{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NpJd3dlOCStH"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/magenta/ddsp/blob/master/ddsp/colab/demos/train_autoencoder.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hMqWDc_m6rUC"
      },
      "source": [
        "\n",
        "##### Copyright 2020 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VNhgka4UKNjf"
      },
      "outputs": [],
      "source": [
        "# Copyright 2020 Google LLC. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SpXo6phTiOQM"
      },
      "source": [
        "# Train a DDSP Autoencoder on GPU\n",
        "\n",
        "This notebook demonstrates how to install the DDSP library and train it for synthesis based on your own data using our command-line scripts. If run inside of Colab, it will automatically use a free Google Cloud GPU.\n",
        "\n",
        "At the end, you'll have a custom-trained checkpoint that you can download to use with the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/master/ddsp/colab/demos/timbre_transfer.ipynb).\n",
        "\n",
        "\u003cimg src=\"https://storage.googleapis.com/ddsp/additive_diagram/ddsp_autoencoder.png\" alt=\"DDSP Autoencoder figure\" width=\"700\"\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wXjcauVRB48S"
      },
      "source": [
        "**Note that we prefix bash commands with a `!` inside of Colab, but you would leave them out if running directly in a terminal.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vn7CQ4GQizHy"
      },
      "source": [
        "## Install Dependencies\n",
        "\n",
        "First we install the required dependencies with `pip`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "VxPuPR0j5Gs7"
      },
      "outputs": [],
      "source": [
        "!pip install -qU ddsp[data_preparation]==1.4.0\n",
        "\n",
        "# Initialize global path for using google drive. \n",
        "DRIVE_DIR = ''"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w0fVn8yUJl_v"
      },
      "source": [
        "## Setup Google Drive (Optional, Recommended)\n",
        "\n",
        "This notebook requires uploading audio and saving checkpoints. While you can do this with direct uploads and downloads, we recommend connecting your Google Drive account. This enables faster file transfer and regular checkpoint saving, so you do not lose your work if the Colab kernel restarts (common when training for more than 12 hours)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L6MXUbL6KeMn"
      },
      "source": [
        "#### Login and mount your drive\n",
        "\n",
        "This will require an authentication code. You should then be able to see your drive in the file browser on the left panel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m33xuTjEKazJ"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a4vmxpj1LC7m"
      },
      "source": [
        "#### Set your base directory\n",
        "* In drive, put all of the audio (.wav, .mp3) files with which you would like to train in a single folder.\n",
        " * Typically works well with 10-20 minutes of audio from a single monophonic source (also, one acoustic environment). \n",
        "* Use the file browser in the left panel to find a folder with your audio, right-click **\"Copy Path\", paste below**, and run the cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "A0bK6P9DMBTb"
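      },
      "outputs": [],
      "source": [
        "#@markdown (ex. `/content/drive/My Drive/...`) Leave blank to skip loading from Drive.\n",
        "DRIVE_DIR = '' #@param {type: \"string\"}\n",
        "\n",
        "import os\n",
        "\n",
        "# Only validate the path when one was provided; a blank value skips Drive.\n",
        "if DRIVE_DIR:\n",
        "  assert os.path.exists(DRIVE_DIR), 'Drive folder not found: ' + DRIVE_DIR\n",
        "  print('Drive Folder Exists:', DRIVE_DIR)\n",
        "else:\n",
        "  print('DRIVE_DIR is blank; skipping Drive.')\n"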
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FELlizMtIxCH"
      },
      "source": [
        "## Make directories to save model and data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qd22WxEQI3FV"
      },
      "outputs": [],
      "source": [
        "AUDIO_DIR = 'data/audio'\n",
        "AUDIO_FILEPATTERN = AUDIO_DIR + '/*'\n",
        "!mkdir -p $AUDIO_DIR\n",
        "\n",
        "if DRIVE_DIR:\n",
        "  SAVE_DIR = os.path.join(DRIVE_DIR, 'ddsp-solo-instrument')\n",
        "else:\n",
        "  SAVE_DIR = '/content/models/ddsp-solo-instrument'\n",
        "!mkdir -p \"$SAVE_DIR\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fb4YD8woYD1H"
      },
      "source": [
        "## Prepare Dataset\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uNhH7nEbX2db"
      },
      "source": [
        "#### Upload training audio\n",
        "\n",
        "Upload audio files to use for training your model. Uses `DRIVE_DIR` if connected to drive, otherwise prompts local upload."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "itVKEzF6m3rY"
      },
      "outputs": [],
      "source": [
        "import glob\n",
        "import os\n",
        "from ddsp.colab import colab_utils\n",
        "\n",
        "if DRIVE_DIR:\n",
        "  mp3_files = glob.glob(os.path.join(DRIVE_DIR, '*.mp3'))\n",
        "  wav_files = glob.glob(os.path.join(DRIVE_DIR, '*.wav'))\n",
        "  audio_files = mp3_files + wav_files\n",
        "else:\n",
        "  audio_files, _ = colab_utils.upload()\n",
        "\n",
        "for fname in audio_files:\n",
        "  target_name = os.path.join(AUDIO_DIR, \n",
        "                             os.path.basename(fname).replace(' ', '_'))\n",
        "  print('Copying {} to {}'.format(fname, target_name))\n",
        "  !cp \"$fname\" $target_name"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g_XVFoN2YOat"
      },
      "source": [
        "### Preprocess raw audio into TFRecord dataset\n",
        "\n",
        "We need to do some preprocessing on the raw audio you uploaded to get it into the correct format for training. This involves turning the full audio into short (4-second) examples, inferring the fundamental frequency (or \"pitch\") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features will then be stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading. Depending on the amount of input audio, this process usually takes a few minutes.\n",
        "\n",
        "* (Optional) Transfer dataset from Drive. If you've already created a dataset in a previous run, this cell skips dataset creation and copies the dataset from `$DRIVE_DIR/data`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MsnkAHyHVrCW"
      },
      "outputs": [],
      "source": [
        "import glob\n",
        "import os\n",
        "\n",
        "TRAIN_TFRECORD = 'data/train.tfrecord'\n",
        "TRAIN_TFRECORD_FILEPATTERN = TRAIN_TFRECORD + '*'\n",
        "\n",
        "# Copy dataset from drive if dataset has already been created.\n",
        "drive_data_dir = os.path.join(DRIVE_DIR, 'data') \n",
        "drive_dataset_files = glob.glob(drive_data_dir + '/*')\n",
        "\n",
        "if DRIVE_DIR and len(drive_dataset_files) \u003e 0:\n",
        "  !cp \"$drive_data_dir\"/* data/\n",
        "\n",
        "else:\n",
        "  # Make a new dataset.\n",
        "  if not glob.glob(AUDIO_FILEPATTERN):\n",
        "    raise ValueError('No audio files found. Please use the previous cell to '\n",
        "                    'upload.')\n",
        "\n",
        "  !ddsp_prepare_tfrecord \\\n",
        "    --input_audio_filepatterns=$AUDIO_FILEPATTERN \\\n",
        "    --output_tfrecord_path=$TRAIN_TFRECORD \\\n",
        "    --num_shards=10 \\\n",
        "    --alsologtostderr\n",
        "\n",
        "  # Copy dataset to drive for safe-keeping.\n",
        "  if DRIVE_DIR:\n",
        "    !mkdir \"$drive_data_dir\"/\n",
        "    print('Saving to {}'.format(drive_data_dir))\n",
        "    !cp $TRAIN_TFRECORD_FILEPATTERN \"$drive_data_dir\"/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d4toX-D-AYZL"
      },
      "source": [
        "### Save dataset statistics for timbre transfer\n",
        "\n",
        "Quantile normalization helps match loudness of timbre transfer inputs to the \n",
        "loudness of the dataset, so let's calculate it here and save in a pickle file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bp_c8P0xApY6"
      },
      "outputs": [],
      "source": [
        "from ddsp.colab import colab_utils\n",
        "import ddsp.training\n",
        "\n",
        "data_provider = ddsp.training.data.TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)\n",
        "dataset = data_provider.get_dataset(shuffle=False)\n",
        "PICKLE_FILE_PATH = os.path.join(SAVE_DIR, 'dataset_statistics.pkl')\n",
        "\n",
        "_ = colab_utils.save_dataset_statistics(data_provider, PICKLE_FILE_PATH, batch_size=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nIsq0HrzbOF7"
      },
      "source": [
        "Let's load the dataset in the `ddsp` library and have a look at one of the examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dA-FOmRgYdpZ"
      },
      "outputs": [],
      "source": [
        "from ddsp.colab import colab_utils\n",
        "import ddsp.training\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "data_provider = ddsp.training.data.TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)\n",
        "dataset = data_provider.get_dataset(shuffle=False)\n",
        "\n",
        "try:\n",
        "  ex = next(iter(dataset))\n",
        "except StopIteration:\n",
        "  raise ValueError(\n",
        "      'TFRecord contains no examples. Please try re-running the pipeline with '\n",
        "      'different audio file(s).')\n",
        "\n",
        "colab_utils.specplot(ex['audio'])\n",
        "colab_utils.play(ex['audio'])\n",
        "\n",
        "f, ax = plt.subplots(3, 1, figsize=(14, 4))\n",
        "x = np.linspace(0, 4.0, 1000)\n",
        "ax[0].set_ylabel('loudness_db')\n",
        "ax[0].plot(x, ex['loudness_db'])\n",
        "ax[1].set_ylabel('F0_Hz')\n",
        "ax[1].set_xlabel('seconds')\n",
        "ax[1].plot(x, ex['f0_hz'])\n",
        "ax[2].set_ylabel('F0_confidence')\n",
        "ax[2].set_xlabel('seconds')\n",
        "ax[2].plot(x, ex['f0_confidence'])\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9gvXBa7PbuyY"
      },
      "source": [
        "## Train Model\n",
        "\n",
        "We will now train a \"solo instrument\" model. This means the model is conditioned only on the fundamental frequency (f0) and loudness, with no instrument ID or latent timbre feature. If you uploaded audio of multiple instruments, the neural network you train will attempt to model all of their timbres, but will likely associate certain timbres with particular f0 and loudness conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YpwQkSIKjEMZ"
      },
      "source": [
        "First, let's start up a [TensorBoard](https://www.tensorflow.org/tensorboard) to monitor our loss as training proceeds. \n",
        "\n",
        "Initially, TensorBoard will report `No dashboards are active for the current data set.`, but once training begins, the dashboards should appear."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u2lx7yJneUXT"
      },
      "outputs": [],
      "source": [
        "%reload_ext tensorboard\n",
        "import tensorboard as tb\n",
        "tb.notebook.start('--logdir \"{}\"'.format(SAVE_DIR))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fT-8Koyvj46w"
      },
      "source": [
        "### We will now begin training. \n",
        "\n",
        "Note that we specify [gin configuration](https://github.com/google/gin-config) files for both the model architecture ([solo_instrument.gin](TODO)) and the dataset ([tfrecord.gin](TODO)), which are both predefined in the library. You could also create your own. We then override some of the specific params for `batch_size` (which is defined in the model gin file) and the tfrecord path (which is defined in the dataset file). \n",
        "\n",
        "### Training Notes:\n",
        "* Models typically perform well when the loss drops to the range of ~4.5-5.0.\n",
        "* Depending on the dataset, this usually takes 5k-30k training steps.\n",
        "* The default is set to 30k, but you can stop training at any time, and for timbre transfer, it's best to stop before the loss drops too far below ~5.0 to avoid overfitting.\n",
        "* On the Colab GPU, this can take around 3-20 hours. \n",
        "* We **highly recommend** saving checkpoints directly to your Drive account, as Colab will restart naturally after about 12 hours and you may lose all of your checkpoints.\n",
        "* By default, checkpoints will be saved every 300 steps with a maximum of 10 checkpoints (at ~60MB/checkpoint this is ~600MB). Feel free to adjust these numbers depending on the frequency of saves you would like and space on your drive.\n",
        "* If you're restarting a session and `DRIVE_DIR` points to a directory that was previously used for training, training should resume at the last checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "poKO-mZEGYXZ"
      },
      "outputs": [],
      "source": [
        "!ddsp_run \\\n",
        "  --mode=train \\\n",
        "  --alsologtostderr \\\n",
        "  --save_dir=\"$SAVE_DIR\" \\\n",
        "  --gin_file=models/solo_instrument.gin \\\n",
        "  --gin_file=datasets/tfrecord.gin \\\n",
        "  --gin_param=\"TFRecordProvider.file_pattern='$TRAIN_TFRECORD_FILEPATTERN'\" \\\n",
        "  --gin_param=\"batch_size=16\" \\\n",
        "  --gin_param=\"train_util.train.num_steps=30000\" \\\n",
        "  --gin_param=\"train_util.train.steps_per_save=300\" \\\n",
        "  --gin_param=\"trainers.Trainer.checkpoints_to_keep=10\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V95qxVjFzWR6"
      },
      "source": [
        "## Resynthesis\n",
        "\n",
        "Check how well the model reconstructs the training data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OQ5PPDZVzgFR"
      },
      "outputs": [],
      "source": [
        "from ddsp.colab.colab_utils import play, specplot\n",
        "import ddsp.training\n",
        "import gin\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "data_provider = ddsp.training.data.TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)\n",
        "dataset = data_provider.get_batch(batch_size=1, shuffle=False)\n",
        "\n",
        "try:\n",
        "  batch = next(iter(dataset))\n",
        "except StopIteration:\n",
        "  raise ValueError(\n",
        "      'TFRecord contains no examples. Please try re-running the pipeline with '\n",
        "      'different audio file(s).')\n",
        "\n",
        "# Parse the gin config.\n",
        "gin_file = os.path.join(SAVE_DIR, 'operative_config-0.gin')\n",
        "gin.parse_config_file(gin_file)\n",
        "\n",
        "# Load model\n",
        "model = ddsp.training.models.Autoencoder()\n",
        "model.restore(SAVE_DIR)\n",
        "\n",
        "# Resynthesize audio.\n",
        "outputs = model(batch, training=False)\n",
        "audio_gen = model.get_audio_from_outputs(outputs)\n",
        "audio = batch['audio']\n",
        "\n",
        "print('Original Audio')\n",
        "specplot(audio)\n",
        "play(audio)\n",
        "\n",
        "print('Resynthesis')\n",
        "specplot(audio_gen)\n",
        "play(audio_gen)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZXM2ynLQ2Wl3"
      },
      "source": [
        "## Download Checkpoint\n",
        "\n",
        "Below you can download the final checkpoint. You are now ready to use it in the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/master/ddsp/colab/demos/timbre_transfer.ipynb)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2WDiCyXP0tNE"
      },
      "outputs": [],
      "source": [
        "from ddsp.colab import colab_utils\n",
        "import tensorflow as tf\n",
        "import os\n",
        "\n",
        "CHECKPOINT_ZIP = 'my_solo_instrument.zip'\n",
        "latest_checkpoint_fname = os.path.basename(tf.train.latest_checkpoint(SAVE_DIR))\n",
        "!cd \"$SAVE_DIR\" \u0026\u0026 zip $CHECKPOINT_ZIP $latest_checkpoint_fname* operative_config-0.gin dataset_statistics.pkl\n",
        "!cp \"$SAVE_DIR/$CHECKPOINT_ZIP\" ./\n",
        "colab_utils.download(CHECKPOINT_ZIP)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "hMqWDc_m6rUC"
      ],
      "name": "train_autoencoder.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
